Mixup Data Augmentation

(1) Mixup blends two randomly chosen samples in a given ratio, and the classification targets are mixed in the same ratio;

(2) Cutout removes a random region of a sample and fills it with zero pixel values; the classification target stays unchanged (a minimal sketch is given right after this list);

(3) CutMix also cuts out a region, but instead of filling it with zeros it pastes in the corresponding region from another training image; the classification targets are mixed according to the area ratio of the two sources.
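Only Mixup and CutMix are implemented later in this post; for completeness, a minimal Cutout sketch could look like the following (the cutout helper and its size parameter are illustrative assumptions, not code from the paper or the linked repository):

import torch

def cutout(images: torch.Tensor, size: int = 16) -> torch.Tensor:
    # Minimal Cutout sketch (hypothetical helper): zero out one random
    # size x size square per image in an NCHW batch; the label is unchanged.
    n, _, h, w = images.shape
    out = images.clone()
    for i in range(n):
        # pick a random center and clamp the square to the image borders
        cy = torch.randint(h, (1,)).item()
        cx = torch.randint(w, (1,)).item()
        y1, y2 = max(cy - size // 2, 0), min(cy + size // 2, h)
        x1, x2 = max(cx - size // 2, 0), min(cx + size // 2, w)
        out[i, :, y1:y2, x1:x2] = 0  # fill the cut region with zero pixel values
    return out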

Paper: mixup: Beyond empirical risk minimization (ICLR 2018)

Abstract: Large deep neural networks are powerful, but exhibit undesirable behaviors such as memorization and sensitivity to adversarial examples. In this work, we propose mixup, a simple learning principle to alleviate these issues. In essence, mixup trains a neural network on convex combinations of pairs of examples and their labels. By doing so, mixup regularizes the neural network to favor simple linear behavior in-between training examples. Our experiments on the ImageNet-2012, CIFAR-10, CIFAR-100, Google commands and UCI datasets show that mixup improves the generalization of state-of-the-art neural network architectures. We also find that mixup reduces the memorization of corrupt labels, increases the robustness to adversarial examples, and stabilizes the training of generative adversarial networks.

Source code: https://github.com/facebookresearch/mixup-cifar10

Mixup adds two sample-label pairs together in a given ratio to produce a new sample-label pair; in other words, it mainly blends two images proportionally.
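In the paper's notation: given two training pairs (x_i, y_i) and (x_j, y_j), the new sample is x̃ = λ·x_i + (1 − λ)·x_j and the new label is ỹ = λ·y_i + (1 − λ)·y_j, where the mixing coefficient λ ∈ [0, 1] is drawn from a Beta(α, α) distribution.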

In implementation, the core idea of Mixup is to randomly pair up images within each batch and blend each pair in a fixed ratio to produce new images; training then uses only these blended images, and the original images no longer take part in training on their own.

First, draw a random mixing ratio from a Beta distribution; the value lies in [0, 1], and Beta(8.0, 8.0) concentrates it around 0.5:

r = np.random.beta(8.0, 8.0) 

Blend image 1 and image 2 according to that ratio:

img = (img1 * r + img2 * (1 - r)).astype(np.uint8)

Concatenate the label information of the two images (for detection-style box labels, no coordinate adjustment is needed):

labels = np.concatenate((labels1, labels2), 0)
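Putting the three fragments above together, a minimal sketch might look like the following (the mixup_det name is made up here; img1/img2 are assumed to be same-size uint8 arrays and labels1/labels2 per-image box-label arrays):

import numpy as np

def mixup_det(img1, labels1, img2, labels2):
    # mixing ratio from Beta(8, 8), concentrated around 0.5
    r = np.random.beta(8.0, 8.0)
    # pixel-wise blend of the two images (assumed to have the same shape)
    img = (img1 * r + img2 * (1 - r)).astype(np.uint8)
    # keep both label sets; box coordinates are unchanged by the blend
    labels = np.concatenate((labels1, labels2), 0)
    return img, labels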
def mixup(data, targets1, alpha=1.0):
    # random permutation of the batch indices
    indices = torch.randperm(data.size(0))
    # image batch in shuffled order
    shuffled_data = data[indices]
    # labels in the same shuffled order
    shuffled_targets1 = targets1[indices]
    # mixing coefficient drawn from Beta(alpha, alpha)
    lam = np.random.beta(alpha, alpha)
    # blend each image with its shuffled partner
    data = data * lam + shuffled_data * (1 - lam)
    # return both label sets plus the coefficient for the loss
    targets = [targets1, shuffled_targets1, lam]
    return data, targets
def mixup(data, targets1, targets2, targets3, alpha):
    indices = torch.randperm(data.size(0))
    shuffled_data = data[indices]
    shuffled_targets1 = targets1[indices]
    shuffled_targets2 = targets2[indices]
    shuffled_targets3 = targets3[indices]

    lam = np.random.beta(alpha, alpha)
    data = data * lam + shuffled_data * (1 - lam)
    targets = [targets1, shuffled_targets1, targets2, shuffled_targets2, targets3, shuffled_targets3, lam]

    return data, targets
def mixup_criterion(preds1, preds2, preds3, targets):
    targets1, targets2, targets3, targets4, targets5, targets6, lam = targets
    criterion = nn.CrossEntropyLoss(reduction='mean')
    return (lam * criterion(preds1, targets1) + (1 - lam) * criterion(preds1, targets2)
            + lam * criterion(preds2, targets3) + (1 - lam) * criterion(preds2, targets4)
            + lam * criterion(preds3, targets5) + (1 - lam) * criterion(preds3, targets6))

Notes on Mixup:
1. Mixing three or more samples does not bring additional gains.
2. The larger the model capacity or the longer the training schedule, the greater the benefit from mixup.

The full CIFAR-10 training script from the official repository (linked above) is reproduced below.

#!/usr/bin/env python3 -u
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree.
from __future__ import print_function

import argparse
import csv
import os

import numpy as np
import torch
from torch.autograd import Variable
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets

import models
from utils import progress_bar

parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
parser.add_argument('--resume', '-r', action='store_true',
                    help='resume from checkpoint')
parser.add_argument('--model', default="ResNet18", type=str,
                    help='model type (default: ResNet18)')
parser.add_argument('--name', default='0', type=str, help='name of run')
parser.add_argument('--seed', default=0, type=int, help='random seed')
parser.add_argument('--batch-size', default=128, type=int, help='batch size')
parser.add_argument('--epoch', default=200, type=int,
                    help='total epochs to run')
parser.add_argument('--no-augment', dest='augment', action='store_false',
                    help='use standard augmentation (default: True)')
parser.add_argument('--decay', default=1e-4, type=float, help='weight decay')
parser.add_argument('--alpha', default=1., type=float,
                    help='mixup interpolation coefficient (default: 1)')
args = parser.parse_args()

use_cuda = torch.cuda.is_available()

best_acc = 0  # best test accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch

if args.seed != 0:
    torch.manual_seed(args.seed)

# Data
print('==> Preparing data..')
if args.augment:
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])
else:
    transform_train = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])


transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = datasets.CIFAR10(root='~/data', train=True, download=False,
                            transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset,
                                          batch_size=args.batch_size,
                                          shuffle=True, num_workers=8)

testset = datasets.CIFAR10(root='~/data', train=False, download=False,
                           transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100,
                                         shuffle=False, num_workers=8)


# Model
if args.resume:
    # Load checkpoint.
    print('==> Resuming from checkpoint..')
    assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
    checkpoint = torch.load('./checkpoint/ckpt.t7' + args.name + '_'
                            + str(args.seed))
    net = checkpoint['net']
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch'] + 1
    rng_state = checkpoint['rng_state']
    torch.set_rng_state(rng_state)
else:
    print('==> Building model..')
    net = models.__dict__[args.model]()

if not os.path.isdir('results'):
    os.mkdir('results')
logname = ('results/log_' + net.__class__.__name__ + '_' + args.name + '_'
           + str(args.seed) + '.csv')

if use_cuda:
    net.cuda()
    net = torch.nn.DataParallel(net)
    print(torch.cuda.device_count())
    cudnn.benchmark = True
    print('Using CUDA..')

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9,
                      weight_decay=args.decay)


def mixup_data(x, y, alpha=1.0, use_cuda=True):
    '''Returns mixed inputs, pairs of targets, and lambda'''
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    batch_size = x.size()[0]
    if use_cuda:
        index = torch.randperm(batch_size).cuda()
    else:
        # torch.randperm(n) returns a random permutation of the integers 0..n-1; the name is short for 'random permutation'
        index = torch.randperm(batch_size)

    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam


def mixup_criterion(criterion, pred, y_a, y_b, lam):
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)


def train(epoch):
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    reg_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()

        inputs, targets_a, targets_b, lam = mixup_data(inputs, targets,
                                                       args.alpha, use_cuda)
        inputs, targets_a, targets_b = map(Variable, (inputs,
                                                      targets_a, targets_b))
        outputs = net(inputs)
        loss = mixup_criterion(criterion, outputs, targets_a, targets_b, lam)
        train_loss += loss.data[0]
        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)
        correct += (lam * predicted.eq(targets_a.data).cpu().sum().float()
                    + (1 - lam) * predicted.eq(targets_b.data).cpu().sum().float())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        progress_bar(batch_idx, len(trainloader),
                     'Loss: %.3f | Reg: %.5f | Acc: %.3f%% (%d/%d)'
                     % (train_loss/(batch_idx+1), reg_loss/(batch_idx+1),
                        100.*correct/total, correct, total))
    return (train_loss/batch_idx, reg_loss/batch_idx, 100.*correct/total)


def test(epoch):
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(testloader):
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        inputs, targets = Variable(inputs, volatile=True), Variable(targets)
        outputs = net(inputs)
        loss = criterion(outputs, targets)

        test_loss += loss.data[0]
        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)
        correct += predicted.eq(targets.data).cpu().sum()

        progress_bar(batch_idx, len(testloader),
                     'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                     % (test_loss/(batch_idx+1), 100.*correct/total,
                        correct, total))
    acc = 100.*correct/total
    if epoch == start_epoch + args.epoch - 1 or acc > best_acc:
        checkpoint(acc, epoch)
    if acc > best_acc:
        best_acc = acc
    return (test_loss/batch_idx, 100.*correct/total)


def checkpoint(acc, epoch):
    # Save checkpoint.
    print('Saving..')
    state = {
        'net': net,
        'acc': acc,
        'epoch': epoch,
        'rng_state': torch.get_rng_state()
    }
    if not os.path.isdir('checkpoint'):
        os.mkdir('checkpoint')
    torch.save(state, './checkpoint/ckpt.t7' + args.name + '_'
               + str(args.seed))


def adjust_learning_rate(optimizer, epoch):
    """decrease the learning rate at 100 and 150 epoch"""
    lr = args.lr
    if epoch >= 100:
        lr /= 10
    if epoch >= 150:
        lr /= 10
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


if not os.path.exists(logname):
    with open(logname, 'w') as logfile:
        logwriter = csv.writer(logfile, delimiter=',')
        logwriter.writerow(['epoch', 'train loss', 'reg loss', 'train acc',
                            'test loss', 'test acc'])

for epoch in range(start_epoch, args.epoch):
    train_loss, reg_loss, train_acc = train(epoch)
    test_loss, test_acc = test(epoch)
    adjust_learning_rate(optimizer, epoch)
    with open(logname, 'a') as logfile:
        logwriter = csv.writer(logfile, delimiter=',')
        logwriter.writerow([epoch, train_loss, reg_loss, train_acc, test_loss,
                            test_acc])

CutMix Data Augmentation
CutMix crops a random rectangular region from one image and pastes it onto another to produce a new image. Labels are handled the same way as in Mixup: the mixed label is weighted by the proportion (here, the area ratio) that each source image contributes to the new sample. This handling suits the spatial continuity of image content better.


https://www.bilibili.com/read/cv23511556

# Randomly sample a crop box whose area fraction is (1 - lam)
def rand_bbox(size, lam):
    # image dimensions taken from the NCHW tensor size
    W = size[2]
    H = size[3]
    # side-length ratio of the cut region
    cut_rat = np.sqrt(1. - lam)
    # cut width
    cut_w = int(W * cut_rat)
    # cut height
    cut_h = int(H * cut_rat)

    # random center point (uniform)
    cx = np.random.randint(W)
    cy = np.random.randint(H)
    # top-left corner, clipped to the image
    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    # bottom-right corner, clipped to the image
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)
    # return the top-left and bottom-right coordinates of the cut region
    return bbx1, bby1, bbx2, bby2
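As a concrete check: for a 32x32 image with lam = 0.75, cut_rat = sqrt(1 - 0.75) = 0.5, so the sampled box is 16x16 (25% of the image area) before border clipping; the cutmix function below then recomputes lam from the actual clipped box area.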
def cutmix(data, targets1, targets2, targets3, alpha):
    indices = torch.randperm(data.size(0))
    shuffled_data = data[indices]
    shuffled_targets1 = targets1[indices]
    shuffled_targets2 = targets2[indices]
    shuffled_targets3 = targets3[indices]

    lam = np.random.beta(alpha, alpha)
    bbx1, bby1, bbx2, bby2 = rand_bbox(data.size(), lam)
    data[:, :, bbx1:bbx2, bby1:bby2] = data[indices, :, bbx1:bbx2, bby1:bby2]
    # adjust lambda to exactly match pixel ratio
    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (data.size()[-1] * data.size()[-2]))

    targets = [targets1, shuffled_targets1, targets2, shuffled_targets2, targets3, shuffled_targets3, lam]
    return data, targets

def mixup(data, targets1, targets2, targets3, alpha):
    indices = torch.randperm(data.size(0))
    shuffled_data = data[indices]
    shuffled_targets1 = targets1[indices]
    shuffled_targets2 = targets2[indices]
    shuffled_targets3 = targets3[indices]

    lam = np.random.beta(alpha, alpha)
    data = data * lam + shuffled_data * (1 - lam)
    targets = [targets1, shuffled_targets1, targets2, shuffled_targets2, targets3, shuffled_targets3, lam]

    return data, targets


def cutmix_criterion(preds1, preds2, preds3, targets):
    targets1, targets2, targets3, targets4, targets5, targets6, lam = targets
    criterion = nn.CrossEntropyLoss(reduction='mean')
    return (lam * criterion(preds1, targets1) + (1 - lam) * criterion(preds1, targets2)
            + lam * criterion(preds2, targets3) + (1 - lam) * criterion(preds2, targets4)
            + lam * criterion(preds3, targets5) + (1 - lam) * criterion(preds3, targets6))

def mixup_criterion(preds1, preds2, preds3, targets):
    targets1, targets2, targets3, targets4, targets5, targets6, lam = targets
    criterion = nn.CrossEntropyLoss(reduction='mean')
    return (lam * criterion(preds1, targets1) + (1 - lam) * criterion(preds1, targets2)
            + lam * criterion(preds2, targets3) + (1 - lam) * criterion(preds2, targets4)
            + lam * criterion(preds3, targets5) + (1 - lam) * criterion(preds3, targets6))

for i, (image_id, images, label1, label2, label3) in enumerate(data_loader_train):
    images = images.to(device)
    label1 = label1.to(device)
    label2 = label2.to(device)
    label3 = label3.to(device)
    # print(image_id, label1, label2, label3)

    # apply mixup or cutmix with 50% probability each
    if np.random.rand() < 0.5:
        images, targets = mixup(images, label1, label2, label3, 0.4)
        output1, output2, output3 = model(images)
        loss = mixup_criterion(output1, output2, output3, targets)
    else:
        images, targets = cutmix(images, label1, label2, label3, 0.4)
        output1, output2, output3 = model(images)
        loss = cutmix_criterion(output1, output2, output3, targets)
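The snippet above stops at the loss computation; in a complete training step the usual backward pass would follow (a minimal sketch, assuming a standard optimizer is already defined):

    # clear old gradients, backpropagate the mixed loss, and update the weights
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()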

MixUp Training Example

# ============================ Imports ============================
from torchvision import transforms,models
from torch.utils.data import Dataset,DataLoader
import pandas as pd
from PIL import Image
import torch
import os
import copy
from sklearn.model_selection import train_test_split
from sklearn.model_selection import StratifiedKFold
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2
import numpy as np
import torch.nn as nn
'''
EDA notes on the dataset:
1. Image sizes are inconsistent.
2. All images are square.
'''
def mixup_data(x, y, alpha=1.0, use_cuda=True):
    '''Returns mixed inputs, pairs of targets, and lambda'''
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    batch_size = x.size()[0]
    if use_cuda:
        index = torch.randperm(batch_size).cuda()
    else:
        # torch.randperm(n) returns a random permutation of the integers 0..n-1 (inclusive)
        index = torch.randperm(batch_size)

    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam



def mixup(data, targets, alpha=1.0):
    # mixing coefficient from Beta(alpha, alpha); if alpha <= 0, fall back to lam = 1 (no mixing)
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1
    # random permutation of the batch indices (data.size(0) is the batch size, e.g. 24)
    indices = torch.randperm(data.size(0))
    # image batch in shuffled order, e.g. shape (24, 3, 324, 324)
    shuffled_data = data[indices]
    # labels in the same shuffled order, e.g. shape (24,)
    shuffled_targets = targets[indices]

    # blend each image with its shuffled partner
    mix_data = data * lam + shuffled_data * (1 - lam)
    # return the mixed images, both label sets, and the coefficient
    return mix_data, targets, shuffled_targets, lam



def mixup_criterion(criterion, pred, y_a, y_b, lam):
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)



# ============================ Helper functions ============================
def find_classes(fulldir):
    # list all class sub-directories
    classes = os.listdir(fulldir)
    classes.sort()
    # class name -> index
    class_to_idx = dict(zip(classes, range(len(classes))))
    # index -> class name
    idx_to_class = {v: k for k, v in class_to_idx.items()}
    train = []

    for i, label in idx_to_class.items():
        path = fulldir + "/" + label
        for file in os.listdir(path):
            train.append([f'{label}/{file}', label, i])
    # relative image path, class name, class index
    df = pd.DataFrame(train, columns=["file", "class", "class_index"])
    return classes, class_to_idx, idx_to_class, df



# ============================ step 0/5: Parameters ============================
# number of training epochs
num_epoch = 10
# batch size
batch_size = 24
# number of data-loading workers
num_workers = 0
# device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# training-set root directory
root_dir = "./plant-seedlings-classification/train"
# class names, mappings and file list
classes, class_to_idx, idx_to_class, df = find_classes(root_dir)
# number of classes
num_classes = len(classes)


# ============================ step 1/5: Data ============================
class PlantDataset(Dataset):
    def __init__(self, dataframe, root_dir, transform=None):
        self.transform = transform
        self.df = dataframe
        self.root_dir = root_dir

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        fullpath = os.path.join(self.root_dir, self.df.iloc[idx][0])
        image = Image.open(fullpath).convert('RGB')
        image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
        image = image.astype(np.float32)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.transform:
            # option 1: torchvision transforms
            # image = self.transform(image)
            # option 2: albumentations
            image = self.transform(image=image)['image']
        return image, self.df.iloc[idx][2]

# Training-data preprocessing
# Option 1: torchvision transforms
# train_transform = transforms.Compose([
#     transforms.RandomRotation(180),
#     transforms.RandomAffine(degrees = 0, translate = (0.3, 0.3)),
#     #transforms.CenterCrop(384),
#     transforms.Resize((324,324)),
#     transforms.RandomHorizontalFlip(),
#     transforms.RandomVerticalFlip(),
#     transforms.ToTensor(),
#     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
# ])

# Option 2: albumentations
train_transform =  A.Compose([
                               A.Resize(324, 324),
                               # A.RandomRotate90(),
                               # A.RandomCrop(256, 256),
                               # A.HorizontalFlip(p=0.5),
                               # A.VerticalFlip(p=0.5),
                               A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                               ToTensorV2(),
                           ])


# Validation-data preprocessing
# Option 1: torchvision transforms
# val_transform = transforms.Compose([
#     # transforms.CenterCrop(384),
#     transforms.Resize((324, 324)),
#     transforms.ToTensor(),
#     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
# ])

# Option 2: albumentations
val_transform = A.Compose([
                                A.Resize(324, 324),
                                # A.RandomRotate90(),
                                # A.RandomCrop(256, 256),
                                # A.HorizontalFlip(p=0.5),
                                # A.VerticalFlip(p=0.5),
                                A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                                ToTensorV2(),
                            ])

# Split into training and validation sets (stratified by class)
X_train, X_val = train_test_split(df, test_size=0.2, random_state=42, stratify=df['class_index'])
# Training Dataset and DataLoader
train_dataset = PlantDataset(X_train, root_dir, train_transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True, drop_last=True)

# Validation Dataset and DataLoader
val_dataset = PlantDataset(X_val, root_dir, val_transform)
val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=num_workers, drop_last=True)

# ============================ step 2/5: Model ============================

'''
Model option 1: ResNet
'''

model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)

# Freeze the backbone.
# Setting param.requires_grad = False does not break backpropagation; it only stops the
# frozen weights and biases from being updated. The usual practice is to freeze parameters
# this way and, when building the optimizer, include only the parameters with requires_grad=True.
for param in model.parameters():
    param.requires_grad = False


# The final fc layer of ResNet classifies 1000 ImageNet categories; for a custom dataset with,
# say, 10 classes, this classifier layer has to be replaced.
# in_features is the input size of the linear layer: for fc = nn.Linear(512, 10), fc.in_features is 512.
num_ftrs = model.fc.in_features
# Replace the fc head
model.fc = torch.nn.Sequential(
    torch.nn.Linear(num_ftrs, 256),
    torch.nn.ReLU(),
    torch.nn.Dropout(0.4),
    torch.nn.Linear(256, num_classes)
)

'''
Model option 2: EfficientNet
'''
# model = models.efficientnet_b2(True)
model = models.efficientnet_b2(weights=models.EfficientNet_B2_Weights.DEFAULT)

# Freeze the backbone
for param in model.parameters():
    param.requires_grad = False

# model.avgpool = nn.AdaptiveAvgPool2d(1)
# Replace the classifier head (efficientnet_b2 outputs 1408 features before the classifier)
model.classifier = nn.Sequential(
    torch.nn.Linear(1408, 256),
    torch.nn.ReLU(),
    torch.nn.Dropout(0.4),
    torch.nn.Linear(256, num_classes)
)

model = model.to(device)

# ============================ step 3/5: Loss ============================
# Cross-entropy loss
criterion = torch.nn.CrossEntropyLoss()


# ============================ step 4/5: Optimizer ============================
# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)

## Learning-rate schedule
exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.7)

# ============================ step 5/5: Training ============================
def train_model(model, criterion, optimizer, scheduler, num_epochs=10, device=device):
    # since = time.time()
    model.to(device)
    # keep a deep copy of the best weights seen so far
    best_model_wts = copy.deepcopy(model.state_dict())
    # best validation accuracy so far
    best_acc = 0.0
    # training loop
    for epoch in range(num_epochs):
        for phase in ["train", "val"]:
            if phase == "train":
                # training mode: weights are updated
                model.train()
                # reset running loss and accuracy
                train_loss = 0.0
                train_acc = 0
                for index, (image, label) in enumerate(train_loader):
                    # images
                    image = image.to(device)
                    # labels
                    label = label.to(device)
                    # mixup augmentation
                    inputs, targets_a, targets_b, lam = mixup(image, label, alpha=1.0)
                    # forward pass
                    outputs = model(inputs)
                    # compute the loss
                    # loss = criterion(y_pred, label)
                    loss = mixup_criterion(criterion, outputs, targets_a, targets_b, lam)
                    # print the current loss
                    print("Epoch", num_epochs, r'/', epoch, "Iteration", len(train_loader), r'/', index, "loss:", loss.item())
                    # accumulate the loss
                    train_loss += loss.item()
                    # clear gradients
                    optimizer.zero_grad()
                    # backpropagation
                    loss.backward()
                    # update the model parameters
                    optimizer.step()
                    # predicted class indices
                    y_pred_class = torch.argmax(torch.softmax(outputs, dim=1), dim=1)
                    # accumulate (approximate) accuracy against the un-mixed labels
                    train_acc += (y_pred_class == label).sum().item() / len(outputs)
                # step the learning-rate scheduler
                scheduler.step()
                # average loss over the epoch
                train_loss /= len(train_loader)
                # average accuracy over the epoch
                train_acc /= len(train_loader)
                # report training loss and accuracy
                print("train_loss:", train_loss, "train_acc", train_acc)
            # validation
            else:
                # evaluation mode: no weight updates
                model.eval()
                # reset validation loss and accuracy
                test_loss, test_acc = 0, 0
                # torch.inference_mode disables gradient tracking and speeds up inference
                with torch.inference_mode():
                    # validation loop
                    for image, label in val_loader:
                        # images and labels
                        image = image.to(device)
                        label = label.to(device)
                        # forward pass
                        test_pred_logits = model(image)
                        # compute the loss
                        loss = criterion(test_pred_logits, label)
                        # accumulate the loss
                        test_loss += loss.item()
                        # predicted class indices (argmax over the class dimension)
                        test_pred_labels = test_pred_logits.argmax(dim=1)
                        # accumulate accuracy
                        test_acc += ((test_pred_labels == label).sum().item() / len(test_pred_labels))
                # average validation loss
                test_loss = test_loss / len(val_loader)
                # average validation accuracy
                test_acc = test_acc / len(val_loader)
                # if this is the best accuracy so far, keep a copy of the weights and save the model
                if test_acc > best_acc:
                    best_acc = test_acc
                    best_model_wts = copy.deepcopy(model.state_dict())
                    # save the best model
                    torch.save(model, "best_model.pt")
        # print the summary for this epoch
        print(
            f"Epoch: {epoch + 1} | "
            f"train_loss: {train_loss:.4f} | "
            f"train_acc: {train_acc:.4f} | "
            f"test_loss: {test_loss:.4f} | "
            f"test_acc: {test_acc:.4f}"
        )

    # after training, load the best weights and return the model
    model.load_state_dict(best_model_wts)
    return model

# Train the model and get the best version
model_ft = train_model(model, criterion, optimizer, exp_lr_scheduler, num_epochs=70)

# Save the best model
torch.save(model_ft, "best_model.pt")


