Resnet18-迁移学习——CIFAR-10图像分类

        迁移学习在计算机视觉任务和自然语言处理任务中经常会用到,并且使用迁移学习,可将预训练模型左为新的模型起点,从而节省时间,提高效率。

        一、特征提取:可以在预先训练好的网络结构后,添加或者修改一个简单的分类器,将源任务上预先训练好的网络模型作为另一个目标任务的特征提取器,只对最后增加的分类器参数进行重新学习,而预先训练好的参数,不会被修改或者冻结。

        CIFAR-10数据集由10个类的60000个32*32彩色图像组成,每个类由6000个图像。其中由50000个训练图像和10000个测试图像组成。本次实例,将采用预训练的resnet18网络,来完成图像分类的任务。

"""
特征提取的实例:
利用迁移学习中特征提取的方法来对CIFAR-10数据集实现对10类无体的分类
预训练模型采用resnet18网络
"""
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torchvision import models
from torchvision.datasets import ImageFolder
from datetime import datetime
import matplotlib.pyplot as plt
import numpy as np

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def main():
    # 加载和预处理数据集
    trans_train = transforms.Compose(
        [transforms.RandomResizedCrop(224),  # 将给定图像随机裁剪为不同的大小和宽高比,然后缩放所裁剪得到的图像为制定的大小;
         # (即先随机采集,然后对裁剪得到的图像缩放为同一大小) 默认scale=(0.08, 1.0)
         transforms.RandomHorizontalFlip(),  # 以给定的概率随机水平旋转给定的PIL的图像,默认为0.5;
         transforms.ToTensor(),
         transforms.Normalize(mean=[0.485, 0.456, 0.406],
                              std=[0.229, 0.224, 0.225])])

    trans_valid = transforms.Compose(
        [transforms.Resize(256),  # 是按照比例把图像最小的一个边长放缩到256,另一边按照相同比例放缩。
         transforms.CenterCrop(224),  # 依据给定的size从中心裁剪
         transforms.ToTensor(),  # 将PIL Image或者 ndarray 转换为tensor,并且归一化至[0-1]
         # 归一化至[0-1]是直接除以255,若自己的ndarray数据尺度有变化,则需要自行修改。
         transforms.Normalize(mean=[0.485, 0.456, 0.406],
                              std=[0.229, 0.224, 0.225])])  # 对数据按通道进行标准化,即先减均值,再除以标准差,注意是 hwc

    trainset = torchvision.datasets.CIFAR10(root='../data', train=True,
                                            download=False, transform=trans_train)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=256,
                                              shuffle=True, num_workers=2)

    testset = torchvision.datasets.CIFAR10(root='../data', train=False,
                                           download=False, transform=trans_valid)
    testloader = torch.utils.data.DataLoader(testset, batch_size=256,
                                             shuffle=False, num_workers=2)

    classes = ('plane', 'car', 'bird', 'cat',
               'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
    # 随机获取部分训练数据
    dataiter = iter(trainloader)
    images, labels = dataiter.next()

    # 显示图像
    imshow(torchvision.utils.make_grid(images[:4]))
    # 打印标签
    print(''.join('%5s ' % classes[labels[j]] for j in range(4)))

    # 使用预训练模型
    resnet = models.resnet18(pretrained=True)

    # 冻结模型参数
    for param in resnet.parameters():
        param.requires_grad = False

    # 修改最后的全连接层改为十分类
    resnet.fc = nn.Linear(512, 10)

    # 查看总参数及训练参数
    total_params = sum(p.numel() for p in resnet.parameters())
    print('总参数个数:{}'.format(total_params))
    total_trainable_params = sum(p.numel() for p in resnet.parameters() if p.requires_grad)
    print('需训练参数个数:{}'.format(total_trainable_params))

    resnet = resnet.to(device)
    criterion = nn.CrossEntropyLoss()  # 损失函数
    # 只需要优化最后一层参数
    optimizer = torch.optim.SGD(resnet.fc.parameters(), lr=1e-3, weight_decay=1e-3, momentum=0.9)  # 优化器

    # train
    train(resnet, trainloader, testloader, 30, optimizer, criterion)


# 计算准确率
def get_acc(output, label):
    total = output.shape[0]
    _, pred_label = output.max(1)
    num_correct = (pred_label == label).sum().item()
    return num_correct / total


# 显示图片
def imshow(img):
    img = img / 2 + 0.5  # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()


# 定义训练函数
def train(net, train_data, valid_data, num_epochs, optimizer, criterion):
    prev_time = datetime.now()
    for epoch in range(num_epochs):
        train_loss = 0
        train_acc = 0
        net = net.train()
        for im, label in train_data:
            im = im.to(device)  # (bs, 3, h, w)
            label = label.to(device)  # (bs, h, w)
            # forward
            output = net(im)
            loss = criterion(output, label)
            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            train_acc += get_acc(output, label)

        cur_time = datetime.now()
        h, remainder = divmod((cur_time - prev_time).seconds, 3600)
        m, s = divmod(remainder, 60)
        time_str = "Time %02d:%02d:%02d" % (h, m, s)
        if valid_data is not None:
            valid_loss = 0
            valid_acc = 0
            net = net.eval()
            for im, label in valid_data:
                im = im.to(device)  # (bs, 3, h, w)
                label = label.to(device)  # (bs, h, w)
                output = net(im)
                loss = criterion(output, label)
                valid_loss += loss.item()
                valid_acc += get_acc(output, label)
            epoch_str = (
                    "Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f, "
                    % (epoch, train_loss / len(train_data),
                       train_acc / len(train_data), valid_loss / len(valid_data),
                       valid_acc / len(valid_data)))
        else:
            epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, " %
                         (epoch, train_loss / len(train_data),
                          train_acc / len(train_data)))
        prev_time = cur_time
        print(epoch_str + time_str)


if __name__ == '__main__':
    main()

训练结果:本次训练了30次,测试集的精度可以达到:75%

二、数据增强

        提高模型的泛化能力的最重要的3大因素时数据、模型和损失函数,其中数据又是3个中最重要的。数据增强技术可以扩大数据量,在图像识别、语言识别等中可以通过水平或者竖直翻转图像,裁剪,色彩变幻,扩展和旋转等数据增强技术。

下面我们介绍一种微调的实例:

在同样的数据集中,我们通过数据增强的数据预处理,来微调网络。

trans_train = transforms.Compose(
        [transforms.RandomResizedCrop(size=256, scale=(0.8, 1.0)),  # 将给定图像随机裁剪为不同的大小和宽高比,然后缩放所裁剪
         # 得到的图像为制定的大小; (即先随机采集,然后对裁剪得到的图像缩放为同一大小) 默认scale=(0.8, 1.0)
         transforms.RandomRotation(degrees=15),  # 随机旋转函数-依degrees随机旋转一定角度
         transforms.ColorJitter(),  # 改变颜色的,随机从-0.5 0.5之间对颜色变化
         transforms.RandomResizedCrop(224),  # 随机长宽比裁剪
         transforms.RandomHorizontalFlip(),  # 以给定的概率随机水平旋转给定的PIL的图像,默认为0.5;
         transforms.ToTensor(),  # 将PIL Image或者 ndarray 转换为tensor,并且归一化至[0-1]
         # 归一化至[0-1]是直接除以255,若自己的ndarray数据尺度有变化,则需要自行修改。
         transforms.Normalize(mean=[0.485, 0.456, 0.406],   # 对数据按通道进行标准化,即先减均值,再除以标准差,注意是 hwc
                              std=[0.229, 0.224, 0.225])])

    trans_valid = transforms.Compose(
        [transforms.Resize(256),  # 是按照比例把图像最小的一个边长放缩到256,另一边按照相同比例放缩。
         transforms.CenterCrop(224),  # 依据给定的size从中心裁剪
         transforms.ToTensor(),
         transforms.Normalize(mean=[0.485, 0.456, 0.406],
                              std=[0.229, 0.224, 0.225])])

微调之后的训练结果:将接近95%,并且如果增加训练次数,验证集的准确率将还会提高。

评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值