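How to do simple transfer learning on your own dataset with PyTorch

This post walks beginners through simple transfer learning with PyTorch on a dataset they have put together themselves.

I found a picture online that illustrates what transfer learning means quite vividly:

The setup uses VGG16 as the backbone, CIFAR10 as the pre-training dataset, AdamW as the optimizer, and ReduceLROnPlateau as the learning-rate scheduler.

Note: the GPU is a 2060 in a Lenovo Legion laptop, so the VGG16 network was adapted to a smaller input (img_size is 64*64*3). In other words, this small demo runs on your own computer; you do not need a server.

GitHub code:

SimpleTransferLearning-Pytorch-master: https://github.com/HanXiaoyiGitHub/SimpleTransferLearning-Pytorch-master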

File structure

D:
|
|
|
|----data|----CIFAR10
         |----pet(my own designed dataset, if you need this dataset, you can contact me by CSDN or Zhihu.)
       
D:
|
|
|
|----PycharmProject----SimpleTransferLearning-Pytorch-master
                            |----tensorboard(args.tensorboard=True, visualization loss)
                            |----log(classification_log)
                            |----checkpoints(save model pretrained:CIFAR10_vgg16.pth, the model for training my own dataset: pets_vgg16.pth)
                            |----models
                            |       |----__init__.py
                            |       |----vgg16.py
                            |
                            |----tool----classification
                            |                  |----train.py
                            |----utils
                                   |----get_logger.py(log)
                                   |----path.py(path)
                                   |----AverageMeter.py(AP)
                                   |----accuracy.py

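Before transfer learning

Network model code

This is vgg16 before any transfer learning. We first pre-train this vgg16 network on CIFAR10, and only then do transfer learning.

vgg16.py (adapted to a 64*64*3 input image)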

from torch import nn


class vgg16(nn.Module):
    def __init__(self, num_classes=1000):
        super(vgg16, self).__init__()
        self.layer1 = nn.Sequential(
            # 1
            # 64*64*3 -> 64*64*64
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),

            # 2
            # 64*64*64 -> 32*32*64
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.layer2 = nn.Sequential(
            # 3
            # 32*32*64 -> 32*32*128
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),

            # 4
            # 32*32*128 -> 16*16*128
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.layer3 = nn.Sequential(
            # 5
            # 16*16*128 -> 16*16*256
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),

            # 6
            # 16*16*256 -> 16*16*256
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),

            # 7
            # 16*16*256 -> 8*8*256
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.layer4 = nn.Sequential(
            # 8
            # 8*8*256 -> 8*8*512
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),

            # 9
            # 8*8*512 -> 8*8*512
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),

            # 10
            # 8*8*512 -> 4*4*512
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.layer5 = nn.Sequential(
            # 11
            # 4*4*512 -> 4*4*512
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),

            # 12
            # 4*4*512 -> 4*4*512
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),

            # 13
            # 4*4*512 -> 2*2*512
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        self.conv = nn.Sequential(
            self.layer1,
            self.layer2,
            self.layer3,
            self.layer4,
            self.layer5,
        )

        self.fc = nn.Sequential(
            # 14
            nn.Flatten(),
            nn.Linear(2 * 2 * 512, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),

            # 15
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
        )
        # 16
        self.classifier = nn.Linear(4096, num_classes)

    def forward(self, x):
        x = self.conv(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        x = self.classifier(x)
        return x
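
The shape comments in the listing can be checked with a little arithmetic. The sketch below is plain Python, not part of the repository, and the helper names are made up for illustration. It assumes the layout shown above: 3x3 convolutions with stride 1 and padding 1, and a 2x2 max-pool with stride 2 at the end of each of the five stages.

```python
def conv_output_size(size, kernel_size=3, stride=1, padding=1):
    """Spatial size after one convolution (integer floor, as in the usual formula)."""
    return (size + 2 * padding - kernel_size) // stride + 1


def pool_output_size(size, kernel_size=2, stride=2):
    """Spatial size after one max-pool."""
    return (size - kernel_size) // stride + 1


def flattened_features(image_size, convs_per_stage=(2, 2, 3, 3, 3), out_channels=512):
    """Number of inputs the first Linear layer needs for a square input image."""
    size = image_size
    for n_convs in convs_per_stage:
        for _ in range(n_convs):
            size = conv_output_size(size)  # 3x3, padding 1: size is unchanged
        size = pool_output_size(size)      # 2x2, stride 2: size is halved
    return size * size * out_channels


print(flattened_features(64))   # 2048, i.e. 2 * 2 * 512 as in the listing
print(flattened_features(224))  # 25088, i.e. 7 * 7 * 512 for a 224*224 input
```

This is why the first fully connected layer takes 2 * 2 * 512 inputs here; with a different image_size that number has to change as well.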

Training code

train.py

import os
import time
import logging
import argparse
import warnings

warnings.filterwarnings('ignore')

import torch
import torchvision
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader
from models.vgg16 import vgg16

from utils.path import *
from utils.accuracy import accuracy
from utils.get_logger import get_logger
from utils.AverageMeter import AverageMeter


def parse_args():
    parser = argparse.ArgumentParser(description='PyTorch Detection Training')
    parser.add_mutually_exclusive_group()
    parser.add_argument('--dataset',
                        type=str,
                        default='CIFAR10',
                        help=' CIFAR10')
    parser.add_argument('--dataset_root',
                        type=str,
                        default=CIFAR_path,
                        help='Dataset root directory path')
    parser.add_argument('--basenet',
                        type=str,
                        default='vgg',
                        help='Pretrained base model')
    parser.add_argument('--depth',
                        type=int,
                        default=16,
                        help='Basenet depth')
    parser.add_argument('--batch_size',
                        type=int,
                        default=64,
                        help='Batch size for training')
    parser.add_argument('--resume',
                        type=str,
                        default=None,
                        help='Checkpoint state_dict file to resume training from')
    parser.add_argument('--num_workers',
                        type=int,
                        default=0,
                        help='Number of workers user in dataloading')
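    parser.add_argument('--cuda',
                        # type=str would turn "--cuda False" into a truthy string
                        type=lambda s: str(s).lower() in ('true', '1', 'yes'),
                        default=True,
                        help='Use CUDA to train model')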
    parser.add_argument('--save_folder',
                        type=str,
                        default=CheckPoints,
                        help='Directory for saving checkpoint models')
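    parser.add_argument('--tensorboard',
                        # type=str would turn "--tensorboard False" into a truthy string
                        type=lambda s: str(s).lower() in ('true', '1', 'yes'),
                        default=False,
                        help='Use tensorboard for loss visualization')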
    parser.add_argument('--log_folder',
                        type=str,
                        default=log,
                        help='Log Folder')
    parser.add_argument('--log_name',
                        type=str,
                        default=classification_train_log,
                        help='Log Name')
    parser.add_argument('--tensorboard_log',
                        type=str,
                        default=tensorboard_log,
                        help='Use tensorboard for loss visualization')
    parser.add_argument('--lr',
                        type=float,
                        default=1e-3,
                        help='learning rate')
    parser.add_argument('--epochs',
                        type=int,
                        default=20,
                        help='Number of epochs')
    parser.add_argument('--num_classes',
                        type=int,
                        default=10,
                        help='the number classes')
    parser.add_argument('--image_size',
                        type=int,
                        default=64,
                        help='image size')
    parser.add_argument('--accumulation_steps',
                        type=int,
                        default=1,
                        help='Gradient acumulation steps', )

    return parser.parse_args()


args = parse_args()

# 1. Torch choose cuda or cpu
if torch.cuda.is_available():
    if args.cuda:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
    if not args.cuda:
        print("WARNING: It looks like you have a CUDA device, but you aren't using it" +
              "\n You can set the parameter of cuda to True.")
        torch.set_default_tensor_type('torch.FloatTensor')
else:
    torch.set_default_tensor_type('torch.FloatTensor')

# os.path.exists returns True/False, never None
if not os.path.exists(args.save_folder):
    os.makedirs(args.save_folder)

# 2. Log
get_logger(args.log_folder, args.log_name)
logger = logging.getLogger(args.log_name)


def train():
    # 3. Create SummaryWriter
    if args.tensorboard:
        from torch.utils.tensorboard import SummaryWriter
        # tensorboard  loss
        writer = SummaryWriter(args.tensorboard_log)

    # 4. Ready dataset
    if args.dataset == 'CIFAR10':
        if args.dataset_root != CIFAR_path:
            raise ValueError('Must specify dataset_root if specifying dataset CIFAR10.')

        elif args.dataset_root is None:
            raise ValueError("Must provide --dataset_root when training on CIFAR10.")

        dataset = torchvision.datasets.CIFAR10(root=args.dataset_root, train=True,
                                               transform=torchvision.transforms.Compose([
                                                   transforms.Resize((args.image_size,
                                                                      args.image_size)),
                                                   torchvision.transforms.ToTensor()]))

    dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=args.batch_size,
                                             shuffle=True, num_workers=args.num_workers,
                                             pin_memory=False, generator=torch.Generator(device='cuda:0'))
    top1 = AverageMeter()
    top5 = AverageMeter()
    losses = AverageMeter()

    # 5. Define train model
    if args.basenet == 'vgg':
        if args.depth == 16:
            model = vgg16(num_classes=args.num_classes)
        else:
            raise ValueError("Unsupported model depth!")
    else:
        raise ValueError('Unsupported model type!')

    if args.cuda:
        if torch.cuda.is_available():
            model = model.cuda()
            model = torch.nn.DataParallel(model).cuda()
    else:
        model = torch.nn.DataParallel(model)

    # 6. Loading weights
    if args.resume:
        other, ext = os.path.splitext(args.resume)
        if ext in ('.pkl', '.pth'):
            print('Loading weights into state dict...')
            model_load = os.path.join(args.save_folder, args.resume)
            model.load_state_dict(torch.load(model_load))
        else:
            print('Sorry only .pth and .pkl files supported.')
    elif args.resume is None:
        print("Initializing weights...")
        # initialize newly added models' weights with xavier method
        model.apply(weights_init)

    model.train()

    iteration = 0

    # 7. Optimizer
    optimizer = optim.AdamW(model.parameters(), lr=args.lr)
    criterion = nn.CrossEntropyLoss()

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, patience=3, verbose=True)

    # 8. Length
    iter_size = len(dataset) // args.batch_size
    print("len(dataset): {}, iter_size: {}".format(len(dataset), iter_size))
    logger.info(f"args - {args}")
    t0 = time.time()

    # 9. Create batch iterator
    for epoch in range(args.epochs):
        t1 = time.time()
        torch.cuda.empty_cache()
        # 10. Load train data
        for data in dataloader:
            iteration += 1
            images, targets = data
            if args.cuda:
                images, targets = images.cuda(), targets.cuda()

            optimizer.zero_grad()
            # 11. Forward
            outputs = model(images)

            if args.cuda:
                criterion = criterion.cuda()

            loss = criterion(outputs, targets)
            loss = loss / args.accumulation_steps

            # 12. Backward
            loss.backward()
            optimizer.step()

            # 13. Measure accuracy and record loss
            acc1, acc5 = accuracy(outputs, targets, topk=(1, 5))
            top1.update(acc1.item(), images.size(0))
            top5.update(acc5.item(), images.size(0))
            losses.update(loss.item(), images.size(0))

            if args.tensorboard:
                writer.add_scalar("train_classification_loss", loss.item(), iteration)

            if iteration % 100 == 0:
                logger.info(
                    f"- epoch: {epoch},  iteration: {iteration}, lr:{optimizer.param_groups[0]['lr']}, "
                    f"top1 acc: {acc1.item():.2f}%, top5 acc: {acc5.item():.2f}%, "
                    f"loss: {loss.item():.3f}, (losses.avg): {losses.avg:.3f} "
                )

        scheduler.step(losses.avg)

        t2 = time.time()
        h_time = (t2 - t1) // 3600
        m_time = ((t2 - t1) % 3600) // 60
        s_time = ((t2 - t1) % 3600) % 60
        print("epoch {} is finished, and the time is {}h{}min{}s".format(epoch, int(h_time), int(m_time), int(s_time)))

        # 14. Save train model
        if epoch != 0 and epoch % 10 == 0:
            print('Saving state, iter:', epoch)
            torch.save(model.state_dict(),
                       args.save_folder + '/' + args.dataset +
                       '_' + args.basenet + '_' + repr(epoch) + '.pth')
        torch.save(model.state_dict(),
                   args.save_folder + '/' + args.dataset + "_" + args.basenet + '.pth')

    if args.tensorboard:
        writer.close()

    t3 = time.time()
    h = (t3 - t0) // 3600
    m = ((t3 - t0) % 3600) // 60
    s = ((t3 - t0) % 3600) % 60
    print("The Finished Time is {}h{}m{}s".format(int(h), int(m), int(s)))
    return top1.avg, top5.avg, losses.avg


def weights_init(m):
    if isinstance(m, nn.Conv2d):
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias, 0.0)
    elif isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight)
        nn.init.constant_(m.bias, 0.0)


if __name__ == '__main__':
    torch.multiprocessing.set_start_method('spawn')
    logger.info("Program started")
    top1, top5, loss = train()
    print("top1 acc: {}, top5 acc: {}, loss:{}".format(top1, top5, loss))
    logger.info("Done!")
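
The training script calls scheduler.step(losses.avg) once per epoch with patience=3. To make that rule concrete, here is a deliberately simplified stand-in written in plain Python. It is not PyTorch's ReduceLROnPlateau: the real class also has options such as threshold, cooldown and min_lr, which are left out here, and the class name PlateauSketch and the factor of 0.1 are assumptions for the illustration.

```python
class PlateauSketch:
    """Simplified reduce-on-plateau rule (illustration only, not PyTorch's class)."""

    def __init__(self, lr, factor=0.1, patience=3):
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, metric):
        if self.best > metric:
            # The monitored loss improved: remember it and reset the counter
            self.best = metric
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        if self.bad_epochs > self.patience:
            # More than `patience` epochs without improvement: shrink the rate
            self.lr *= self.factor
            self.bad_epochs = 0
        return self.lr


sched = PlateauSketch(lr=1e-3)
epoch_losses = [1.0, 0.9, 0.9, 0.9, 0.9, 0.9]
history = [sched.step(loss) for loss in epoch_losses]
print(history)  # the rate stays at 1e-3 for five epochs, then drops by the factor
```

With patience=3 the rate only drops on the fourth epoch in a row without improvement, so a 20-epoch run will see few reductions unless the loss really stalls.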

Utility code

utils/get_logger.py

import logging
import os
import os.path


def get_logger(log_dir, log_name):
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    # 1. create logger
    logger = logging.getLogger(log_name)
    logger.setLevel(logging.INFO)

    # 2. create log
    log_name = os.path.join(log_dir, '{}.info.log'.format(log_name))

    # 3. setting output formatter
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

    # 4. log file processor
    fh = logging.FileHandler(log_name)
    fh.setFormatter(formatter)

    # 5. setting screen stdout output processor
    sh = logging.StreamHandler(stream=None)
    sh.setFormatter(formatter)

    # 6. add the processor to the logger
    logger.addHandler(fh)  # add file
    logger.addHandler(sh)  # add sh

    return logger
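
One thing to watch with this helper: logging.getLogger returns the same logger object for the same name, so calling get_logger twice with one name attaches the handlers twice and every message is then written twice. A minimal sketch of a guard, assuming a stream handler only (the function name get_logger_once is hypothetical and not in the repository):

```python
import logging


def get_logger_once(log_name):
    """Hypothetical variant of get_logger: attach a handler only the first time."""
    logger = logging.getLogger(log_name)
    logger.setLevel(logging.INFO)
    if not logger.handlers:
        sh = logging.StreamHandler()
        sh.setFormatter(logging.Formatter(
            '%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
        logger.addHandler(sh)
    return logger


first = get_logger_once("classification_train_demo")
second = get_logger_once("classification_train_demo")
second.info("this line is emitted once")
print(first is second, len(second.handlers))  # True 1
```

In train.py the helper is called only once, so the guard matters mainly if several modules start requesting the same logger.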

utils/AverageMeter.py

class AverageMeter:
    '''Computes and stores the average and current value'''

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
        return self.avg
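
The n argument is what makes the average correct when batches differ in size, which happens for the last batch of an epoch. A small usage sketch with made-up loss values (the class is repeated so the snippet runs on its own):

```python
class AverageMeter:
    '''Computes and stores the average and current value'''

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
        return self.avg


meter = AverageMeter()
meter.update(0.5, n=64)  # a full batch of 64 samples with mean loss 0.5
meter.update(1.0, n=16)  # a final, smaller batch of 16 samples with mean loss 1.0
print(meter.avg)         # 0.6, the per-sample mean; a plain mean of 0.5 and 1.0 would give 0.75
```

Note that train.py creates its meters once, before the epoch loop, so losses.avg there is a running average over everything seen so far. Calling reset() at the start of each epoch would give per-epoch averages instead.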

utils/path.py

import os.path
import sys
BASE_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(BASE_DIR)

# Gets home dir cross platform
# "D:\pycharmProject"
HOME = BASE_DIR
# Must match the project folder name shown in the file structure above
MyName = "SimpleTransferLearning-Pytorch-master"

# Path to store checkpoint model
CheckPoints = 'checkpoints'
CheckPoints = os.path.join(HOME, MyName, CheckPoints)

# Results
Results = 'results'
Results = os.path.join(HOME, MyName, Results)

# Path to store tensorboard load
tensorboard_log = 'tensorboard'
tensorboard_log = os.path.join(HOME, MyName, tensorboard_log)

# Path to save log
log = 'log'
log = os.path.join(HOME, MyName, log)

# Path to save classification train log
classification_train_log = 'classification_train'

# Path to save classification test log
classification_test_log = 'classification_test'

# Path to save classification eval log
classification_eval_log = 'classification_eval'

# Images path
image = '000001.jpg'
images_path = 'images/classcification'
images_path = os.path.join(HOME, MyName, images_path, image)

# Data
DATAPATH = os.path.dirname(BASE_DIR)

# CIFAR10

CIFAR_path = os.path.join(DATAPATH, 'data', 'cifar')
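
The three nested os.path.dirname calls are easy to miscount. The sketch below replays the same walk with pathlib on a fixed Windows path, assuming the project lives at the location shown in the file structure. PureWindowsPath is used so the snippet behaves the same on any operating system; it is only an illustration, not a replacement for path.py.

```python
from pathlib import PureWindowsPath

# Assumed location of utils/path.py, taken from the file structure in this post
path_py = PureWindowsPath(
    r"D:\PycharmProject\SimpleTransferLearning-Pytorch-master\utils\path.py")

# parents[0] is utils, parents[1] the project folder, parents[2] its parent:
# the same result as calling os.path.dirname three times
base_dir = path_py.parents[2]

# DATAPATH goes one level further up, to the drive root
data_path = base_dir.parent
cifar_path = data_path / "data" / "cifar"

print(base_dir)    # D:\PycharmProject
print(cifar_path)  # D:\data\cifar
```

So HOME resolves to the folder that contains the project, which is why every path in path.py joins MyName back in.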

utils/accuracy.py

import torch


def accuracy(output, target, topk=(1,)):
    '''Computes the accuracy over the k top predictions for the specified values of k'''
    with torch.no_grad():
        maxk = max(topk)

        # Total element
        batch_size = target.size(0)

        # The  topk function selects the top k number of output
        # values means element
        # pred means index
        values, pred = output.topk(maxk, 1, True, True)

        # Transpose
        pred = pred.t()

        # correct: tensor([[True, True, False, False],
        #                  [False, False, True, True],
        #                  [False, False, False, False]])
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # correct_k means the total number of elements meeting requirements
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            # [] %
            res.append(correct_k.mul(100.0 / batch_size))
        return res


if __name__ == '__main__':
    outputs = torch.tensor([[-0.5816, -0.3873, -1.0215, -1.0145, 0.4053],
                            [0.7265, 1.4164, 1.3443, 1.2035, 1.8823],
                            [-0.4451, 0.1673, 1.2590, -2.0757, 1.7255],
                            [0.2021, 0.3041, 0.1383, 0.3849, -1.6311]])
    target = torch.tensor([[4], [4], [2], [1]])
    res = accuracy(outputs, target, topk=(1, 3))
    print("res: ", res)
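
To see what the tensor operations compute, the same top-k check can be written with plain Python lists. This is a sketch for understanding only (the function name topk_accuracy is made up); it uses the scores and targets from the test block above.

```python
def topk_accuracy(scores, targets, topk=(1,)):
    """Percentage of rows whose target index is among the k highest scores."""
    results = []
    for k in topk:
        hits = 0
        for row, target in zip(scores, targets):
            # Class indices sorted from highest to lowest score
            ranked = sorted(range(len(row)), key=lambda i: row[i], reverse=True)
            if target in ranked[:k]:
                hits += 1
        results.append(100.0 * hits / len(targets))
    return results


scores = [[-0.5816, -0.3873, -1.0215, -1.0145, 0.4053],
          [0.7265, 1.4164, 1.3443, 1.2035, 1.8823],
          [-0.4451, 0.1673, 1.2590, -2.0757, 1.7255],
          [0.2021, 0.3041, 0.1383, 0.3849, -1.6311]]
targets = [4, 4, 2, 1]

print(topk_accuracy(scores, targets, topk=(1, 3)))  # [50.0, 100.0]
```

Rows 0 and 1 are right at top-1; rows 2 and 3 only have their target among the three highest scores, hence 50% top-1 and 100% top-3.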

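Pre-training results

Transfer learning

My own dataset: a cat/dog dataset (150 cat images, 150 dog images)

Location

D:\data\pet

For how to define a dataset, see my earlier post:

1. 初识Pytorch之Dataset (A first look at Dataset in PyTorch): https://blog.csdn.net/XiaoyYidiaodiao/article/details/121960837

If you need the dataset, leave a comment below.

I found the dataset awkward to use in this form, so I changed its layout to contain only images and labels.

Start with the cats: move the images out of "cat" into "pet", rename them with even numbers, then delete the empty "cat" folder.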

ReadFileName.py

import os

filePath = "D:\\data\\pet"
old_names = os.path.join(filePath, "cat")
old_names = os.listdir(old_names)
for idx, name in enumerate(old_names):
    new_name = os.path.join(filePath, str((idx+1) * 2) + ".jpg")
    old_name = os.path.join(filePath, "cat", name)
    print("idx: {} | old_name:{} | new_name:{} ".format(idx, old_name, new_name))
    os.rename(old_name, new_name)

After reading the file names, write the labels (cat is 0, dog is 1) into label.txt

WriteLabel.py

import os

filePath = "D:\\data\\pet"
label = os.path.join(filePath, "label.txt")
old_names = os.path.join(filePath)
old_names = os.listdir(old_names)
for idx, name in enumerate(old_names):
    other, ext = os.path.splitext(name)
    if ext == '.jpg':
        print("idx:{} | name: {} ".format(idx, name))
        with open(label, "a") as f:
            f.write("image_name : " + name + " , category_id : 0 , category : cat;\n")

Next come the dogs, handled the same way.

Move the images out of "dog" into "pet", rename them with odd numbers, then delete the empty "dog" folder.

ReadFileName.py (modified)

import os

filePath = "D:\\data\\pet"
old_names = os.path.join(filePath, "dog")
old_names = os.listdir(old_names)
for idx, name in enumerate(old_names):
    num = (idx + 1) * 2 - 1
    new_name = os.path.join(filePath, str(num) + ".jpg")
    old_name = os.path.join(filePath, "dog", name)
    print("idx: {} | old_name:{} | new_name:{} ".format(idx, old_name, new_name))
    os.rename(old_name, new_name)
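
Because os.rename cannot be undone easily, it can help to print the rename plan first. The sketch below is a dry run under two assumptions that are not in the scripts above: the names are sorted first (os.listdir returns them in arbitrary order), and the helper name plan_renames is made up. Cats start at 2 and dogs at 1, both with a step of 2, which reproduces the even/odd numbering.

```python
def plan_renames(names, first_number, step=2, ext=".jpg"):
    """Return (old_name, new_name) pairs without touching the file system."""
    plan = []
    for idx, name in enumerate(sorted(names)):
        plan.append((name, str(first_number + idx * step) + ext))
    return plan


cats = plan_renames(["cat_b.jpg", "cat_a.jpg"], first_number=2)  # even numbers
dogs = plan_renames(["dog_b.jpg", "dog_a.jpg"], first_number=1)  # odd numbers

print(cats)  # [('cat_a.jpg', '2.jpg'), ('cat_b.jpg', '4.jpg')]
print(dogs)  # [('dog_a.jpg', '1.jpg'), ('dog_b.jpg', '3.jpg')]
```

Once the plan looks right, each pair can be passed to os.rename together with the folder paths, as the scripts above do.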

WriteLabel.py

import os

filePath = "D:\\data\\pet"
label = os.path.join(filePath, "label.txt")
old_names = os.path.join(filePath)
old_names = os.listdir(old_names)
for idx, name in enumerate(old_names):
    other, ext = os.path.splitext(name)
    if ext == '.jpg':
        if int(other) % 2 == 1:
            print("idx:{} | name: {} ".format(idx, name))
            with open(label, "a") as f:
                f.write("image_name : " + name + " , category_id : 1 , category : dog;\n")

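Tidy up label.txt (purely for readability)

Then reorganize the dataset into separate images and labels (the dataset class below reads images from imgs and labels from annots\label.txt)

Utility updates

Add dataloader.py

utils\dataloader.py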

import os
from PIL import Image
from .path import PETS_path
from torch.utils.data import Dataset


class pets(Dataset):
    def __init__(self, root_dir, label_dir='annots', transform=None):
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.transform = transform
        self.img_path = os.path.join(self.root_dir, 'imgs')
        self.img_path = os.listdir(self.img_path)
        self.label_dir = os.path.join(self.root_dir, self.label_dir, 'label.txt')

    def __getitem__(self, idx):
        annot = self.load_annot(idx)
        img = self.load_image(idx)
        if self.transform:
            img = self.transform(img)

        return img, annot

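    def load_annot(self, idx):
        # Each line looks like:
        # image_name : 101.jpg , category_id : 1 , category : dog;
        img_name = self.img_path[idx]
        with open(self.label_dir, "r") as f:
            for line in f:
                fields = line.strip().rstrip(";").split(",")
                # Compare the whole file name: a substring test lets "1.jpg"
                # match "101.jpg", and a fixed index such as line[37] only
                # works when every file name has the same length.
                if len(fields) >= 2 and fields[0].partition(":")[2].strip() == img_name:
                    return fields[1].partition(":")[2].strip()
        raise KeyError("no label found for {}".format(img_name))

    def load_image(self, idx):
        img_name = self.img_path[idx]
        img_item_path = os.path.join(self.root_dir, 'imgs', img_name)
        # Force 3 channels so grayscale or RGBA files still match the network input
        img = Image.open(img_item_path).convert('RGB')
        return img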

    def __len__(self):
        return len(self.img_path)


if __name__ == '__main__':
    root_dir = PETS_path
    dataset = pets(root_dir=root_dir)
    img, annot = dataset[3]
    # No transform is passed here, so img is a PIL image: report its size
    print("img.size: {}, annot: {}".format(img.size, annot))
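
The dataset class opens label.txt again for every sample it returns. With 300 images that is harmless, but reading the file once into a dictionary is a common alternative. A minimal sketch, assuming the line format written by WriteLabel.py; the function name read_labels and the temporary file are only for the demonstration.

```python
import os
import tempfile


def read_labels(label_file):
    """Parse label.txt once into {image_name: category_id}."""
    labels = {}
    with open(label_file, "r") as f:
        for line in f:
            fields = line.strip().rstrip(";").split(",")
            if len(fields) >= 2:
                name = fields[0].partition(":")[2].strip()
                category_id = fields[1].partition(":")[2].strip()
                labels[name] = int(category_id)
    return labels


# Write a tiny label file in the same format, then read it back
with tempfile.TemporaryDirectory() as tmp:
    label_file = os.path.join(tmp, "label.txt")
    with open(label_file, "w") as f:
        f.write("image_name : 1.jpg , category_id : 1 , category : dog;\n")
        f.write("image_name : 2.jpg , category_id : 0 , category : cat;\n")
        f.write("image_name : 101.jpg , category_id : 1 , category : dog;\n")
    labels = read_labels(label_file)

print(labels)  # {'1.jpg': 1, '2.jpg': 0, '101.jpg': 1}
```

In the dataset class such a dictionary could be built once in __init__ and looked up by file name in __getitem__.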

 

utils\path.py

.....
# (append after the existing code)
# mydataset pets
PETS_path = os.path.join(DATAPATH, 'data', 'pet')

Add collate.py to the utilities

utils\collate.py

import torch
import numpy as np


def collate(data):
    imgs = []
    annots = []
    for d in data:
        img, annot = d
        imgs.append(img)
        annots.append(int(annot))
    # all imgs have the same height and width
    # img.shape = 3, height,width
    height = imgs[0].shape[1]
    width = imgs[0].shape[2]
    batch_size = len(imgs)
    padded_imgs = np.zeros((batch_size, 3, height, width),
                           dtype=np.float32)
    for i, img in enumerate(imgs):
        padded_imgs[i, :, :, :] = img
    padded_imgs = torch.from_numpy(padded_imgs)
    # CrossEntropyLoss expects class indices as int64 (long), not float64
    targets = np.zeros(batch_size, dtype=np.int64)
    for i, annot in enumerate(annots):
        targets[i] = int(annot)
    annots = torch.from_numpy(targets)
    return padded_imgs, annots

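Model changes

The vgg model requested here is the vgg16 model

models\vgg16.py

Drop the weights of the old final layer (the 10-class classifier trained on CIFAR10) from the checkpoint, because their shape no longer matches the new 2-class head. The checkpoint was saved in train.py from an nn.DataParallel model, so its keys carry a "module." prefix that also has to be removed before loading into the bare model:

pretrained_models = {k[len('module.'):] if k.startswith('module.') else k: v
                     for k, v in pretrained_models.items()
                     if 'classifier' not in k}

Full code

import torch
from torch import nn
from utils.path import CheckPoints

__all__ = [
    'vgg16',
]
model_urls = {
    'vgg16': '{}/CIFAR10_vgg.pth'.format(CheckPoints)
}


def _vgg(arch, num_classes, pretrained, progress, **kwargs):
    model = vgg(num_classes=num_classes, **kwargs)

    if pretrained:
        # Only touch the checkpoint when pretrained weights are requested
        pretrained_models = torch.load(model_urls["vgg" + arch])
        # Strip the DataParallel "module." prefix and drop the old classifier head
        pretrained_models = {k[len('module.'):] if k.startswith('module.') else k: v
                             for k, v in pretrained_models.items()
                             if 'classifier' not in k}
        # strict=False tolerates the classifier keys that are now missing
        model.load_state_dict(pretrained_models, strict=False)
    return model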


def vgg16(num_classes, pretrained=False, progress=True, **kwargs):
    return _vgg('16', num_classes, pretrained, progress, **kwargs)


class vgg(nn.Module):
    def __init__(self, num_classes=2):
        super(vgg, self).__init__()
        self.layer1 = nn.Sequential(
            # 1
            # 64*64*3 -> 64*64*64
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),

            # 2
            # 64*64*64 -> 32*32*64
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.layer2 = nn.Sequential(
            # 3
            # 32*32*64 -> 32*32*128
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),

            # 4
            # 32*32*128 -> 16*16*128
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.layer3 = nn.Sequential(
            # 5
            # 16*16*128 -> 16*16*256
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),

            # 6
            # 16*16*256 -> 16*16*256
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),

            # 7
            # 16*16*256 -> 8*8*256
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.layer4 = nn.Sequential(
            # 8
            # 8*8*256 -> 8*8*512
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),

            # 9
            # 8*8*512 -> 8*8*512
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),

            # 10
            # 8*8*512 -> 4*4*512
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.layer5 = nn.Sequential(
            # 11
            # 4*4*512 -> 4*4*512
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),

            # 12
            # 4*4*512 -> 4*4*512
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),

            # 13
            # 4*4*512 -> 2*2*512
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        self.conv = nn.Sequential(
            self.layer1,
            self.layer2,
            self.layer3,
            self.layer4,
            self.layer5,
        )

        self.fc = nn.Sequential(
            # 14
            nn.Flatten(),
            nn.Linear(2 * 2 * 512, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),

            # 15
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
        )
        # 16
        self.classifier = nn.Linear(4096, num_classes)

    def forward(self, x):
        x = self.conv(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        x = self.classifier(x)
        return x
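
The key handling when loading the checkpoint can be tried without PyTorch. In the sketch below a plain dict with placeholder strings stands in for a real state_dict, and the function name adapt_checkpoint_keys is hypothetical. It strips the "module." prefix that nn.DataParallel adds and drops the old classifier entries. Dropping them matters because load_state_dict with strict=False skips missing or unexpected keys but still raises an error when a key is present with a different shape.

```python
def adapt_checkpoint_keys(state_dict, drop_prefix="classifier", wrapper_prefix="module."):
    """Return a copy without the wrapper prefix and without the old head."""
    adapted = {}
    for key, value in state_dict.items():
        if key.startswith(wrapper_prefix):
            key = key[len(wrapper_prefix):]
        if key.startswith(drop_prefix):
            continue  # old 10-class head: shape no longer matches
        adapted[key] = value
    return adapted


# Placeholder values; a real checkpoint maps these keys to tensors
checkpoint = {
    "module.layer1.0.weight": "conv weights",
    "module.fc.1.weight": "fc weights",
    "module.classifier.weight": "old 10-class weights",
    "module.classifier.bias": "old 10-class bias",
}

print(sorted(adapt_checkpoint_keys(checkpoint)))  # ['fc.1.weight', 'layer1.0.weight']
```

The new classifier layer then keeps its fresh random initialization, and the remaining layers start from the CIFAR10 weights.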

Training code changes

1. Add a --pretrained argument that points to the weights just trained

    parser.add_argument('--pretrained',
                        type=str,
                        default='CIFAR10_vgg.pth',
                        help='Pretrained checkpoint (state_dict) to start transfer learning from')

2. Because the dataset is small, increase the number of epochs

    parser.add_argument('--epochs',
                        type=int,
                        default=50,
                        help='Number of epochs')

3. Loading the dataset

    parser.add_argument('--dataset',
                        type=str,
                        default='pets',
                        choices=['CIFAR10', 'pets'],
                        help=' CIFAR10 or pets')
    parser.add_argument('--dataset_root',
                        type=str,
                        default=PETS_path,
                        choices=[CIFAR_path, PETS_path],
                        help='Dataset root directory path')

    if args.dataset == 'CIFAR10':
        if args.dataset_root != CIFAR_path:
            raise ValueError('Must specify dataset_root if specifying dataset CIFAR10.')

        elif args.dataset_root is None:
            raise ValueError("Must provide --dataset_root when training on CIFAR10.")

        dataset = torchvision.datasets.CIFAR10(root=args.dataset_root, train=True,
                                               transform=torchvision.transforms.Compose([
                                                   transforms.Resize((args.image_size,
                                                                      args.image_size)),
                                                   torchvision.transforms.ToTensor()]))
    elif args.dataset == 'pets':
        if args.dataset_root != PETS_path:
            raise ValueError('Must specify dataset_root if specifying dataset PETS.')

        elif args.dataset_root is None:
            raise ValueError("Must provide --dataset_root when training on PETS.")

        dataset = pets(root_dir=args.dataset_root, transform=transforms.Compose([
            transforms.Resize((args.image_size, args.image_size)), transforms.ToTensor()]))

    dataloader = torch.utils.data.DataLoader(dataset=dataset,
                                             batch_size=args.batch_size,
                                             num_workers=args.num_workers,
                                             shuffle=True,
                                             collate_fn=collate)

4.因为数据集比较小300多张图片,iter设置为1打印一次,并且只有猫狗两个类别,没有了top5 acc

            if iteration % 1 == 0:
                logger.info(
                    f"- epoch: {epoch},  iteration: {iteration}, lr:{optimizer.param_groups[0]['lr']}, "
                    f"top1 acc: {acc1.item():.2f}% , "
                    f"loss: {loss.item():.3f}, (losses.avg): {losses.avg:.3f} "
                )
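
With only two classes, top-5 accuracy has no meaning: ranking the five highest scores is impossible when each sample has only two. As a rough sketch of what a top-1 computation looks like (the repository's own utils/accuracy.py may differ; `top1_accuracy` and the sample batch below are hypothetical and used only for illustration):

```python
import torch


def top1_accuracy(outputs: torch.Tensor, targets: torch.Tensor) -> float:
    """Return top-1 accuracy in percent for a batch of logits."""
    # argmax over the class dimension gives the predicted label per sample
    preds = outputs.argmax(dim=1)
    correct = (preds == targets).sum().item()
    return 100.0 * correct / targets.size(0)


# Hypothetical batch: 4 samples, 2 classes (cat = 0, dog = 1)
logits = torch.tensor([[2.0, 0.5], [0.1, 1.2], [0.3, 0.9], [1.5, 1.0]])
labels = torch.tensor([0, 1, 0, 0])
batch_acc = top1_accuracy(logits, labels)
```

Three of the four predictions match their labels here, so `batch_acc` is 75.0.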

5. Save a checkpoint every 5 epochs

        if epoch != 0 and epoch % 5 == 0:
            print('Saving state, iter:', epoch)
            torch.save(model.state_dict(),
                       args.save_folder + '/' + args.dataset +
                       '_' + args.basenet + str(args.depth) + '_' + repr(epoch) + '.pth')
        torch.save(model.state_dict(),
                   args.save_folder + '/' + args.dataset + "_" + args.basenet + str(args.depth) + '.pth')
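
Note that the snippet above writes two kinds of files: a numbered snapshot every 5 epochs, and an unnumbered file that is overwritten after every epoch. The naming scheme can be sketched as a small helper (`checkpoint_path` is a hypothetical name, not part of the repository, and uses `os.path.join` instead of string concatenation):

```python
import os


def checkpoint_path(save_folder, dataset, basenet, depth, epoch=None):
    """Build a checkpoint file name such as pets_vgg16_5.pth or pets_vgg16.pth."""
    name = f"{dataset}_{basenet}{depth}"
    if epoch is not None:
        # Numbered snapshot, written only every few epochs
        name += f"_{epoch}"
    return os.path.join(save_folder, name + ".pth")


periodic = checkpoint_path("checkpoints", "pets", "vgg", 16, epoch=5)
latest = checkpoint_path("checkpoints", "pets", "vgg", 16)
```

Here `periodic` ends in pets_vgg16_5.pth and `latest` ends in pets_vgg16.pth.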

6. Which weights backpropagation updates (all of them or only some)

1) If the new dataset is a subset of the previous dataset, optimize only the weights of the last layer

optimizer = optim.AdamW(model.classifier.parameters(), lr=args.lr)

2) If the new dataset is not contained in the previous dataset, optimize all of the model's weights

optimizer = optim.AdamW(model.parameters(), lr=args.lr)

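In fact, if the vgg16 model had not been modified, option 1) would also work for this experiment; because the model was modified for study purposes, option 2) is used.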
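
Passing only `model.classifier.parameters()` to the optimizer stops the other layers from being stepped, but autograd still computes gradients for them. A common companion step is to set `requires_grad = False` on the backbone. The following is a minimal sketch under stated assumptions: `TinyNet` is a hypothetical stand-in with a `features` backbone and a `classifier` head, not the article's vgg16. Also note that when the model is wrapped in `nn.DataParallel`, as in train.py, the head is reached through `model.module.classifier` rather than `model.classifier`.

```python
import torch
import torch.nn as nn
import torch.optim as optim


class TinyNet(nn.Module):
    """Hypothetical stand-in for a backbone + classifier model."""

    def __init__(self, num_classes=2):
        super().__init__()
        self.features = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
                                      nn.AdaptiveAvgPool2d(1), nn.Flatten())
        self.classifier = nn.Linear(8, num_classes)

    def forward(self, x):
        return self.classifier(self.features(x))


model = TinyNet()

# Option 1: train only the classifier and freeze everything else
for p in model.features.parameters():
    p.requires_grad = False
optimizer = optim.AdamW(model.classifier.parameters(), lr=1e-3)

before = model.features[0].weight.clone()
loss = nn.CrossEntropyLoss()(model(torch.randn(4, 3, 64, 64)), torch.tensor([0, 1, 0, 1]))
loss.backward()
optimizer.step()

# The frozen backbone received no gradient and did not change
backbone_unchanged = torch.equal(before, model.features[0].weight)
trainable = [n for n, p in model.named_parameters() if p.requires_grad]
```

For option 2), skip the freezing loop and pass `model.parameters()` to the optimizer instead.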

7. Set num_classes=2 and remove top5

    parser.add_argument('--num_classes',
                        type=int,
                        default=2,
                        help='the number classes')

Complete code of train.py

import os
import time
import logging
import argparse
import warnings

warnings.filterwarnings('ignore')

import torch
import torchvision
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader
from models.vgg16 import vgg16

from utils.path import *
from utils.dataloader import pets
from utils.accuracy import accuracy
from utils.collate import collate
from utils.get_logger import get_logger
from utils.AverageMeter import AverageMeter


def parse_args():
    parser = argparse.ArgumentParser(description='PyTorch Classification Training')
    parser.add_mutually_exclusive_group()
    parser.add_argument('--dataset',
                        type=str,
                        default='pets',
                        choices=['CIFAR10', 'pets'],
                        help=' CIFAR10 or pets')
    parser.add_argument('--dataset_root',
                        type=str,
                        default=PETS_path,
                        choices=[CIFAR_path, PETS_path],
                        help='Dataset root directory path')
    parser.add_argument('--basenet',
                        type=str,
                        default='vgg',
                        help='Pretrained base model')
    parser.add_argument('--depth',
                        type=int,
                        default=16,
                        help='Basenet depth')
    parser.add_argument('--batch_size',
                        type=int,
                        default=64,
                        help='Batch size for training')
    parser.add_argument('--pretrained',
                        type=str,
                        default=True,
                        help='Load pretrained weights into the model')
    parser.add_argument('--num_workers',
                        type=int,
                        default=0,
                        help='Number of workers used in data loading')
    parser.add_argument('--cuda',
                        type=str,
                        default=True,
                        help='Use CUDA to train model')
    parser.add_argument('--save_folder',
                        type=str,
                        default=CheckPoints,
                        help='Directory for saving checkpoint models')
    parser.add_argument('--tensorboard',
                        type=str,
                        default=False,
                        help='Use tensorboard for loss visualization')
    parser.add_argument('--log_folder',
                        type=str,
                        default=log,
                        help='Log Folder')
    parser.add_argument('--log_name',
                        type=str,
                        default=classification_train_log,
                        help='Log Name')
    parser.add_argument('--tensorboard_log',
                        type=str,
                        default=tensorboard_log,
                        help='Use tensorboard for loss visualization')
    parser.add_argument('--lr',
                        type=float,
                        default=1e-3,
                        help='learning rate')
    parser.add_argument('--epochs',
                        type=int,
                        default=50,
                        help='Number of epochs')
    parser.add_argument('--num_classes',
                        type=int,
                        default=2,
                        help='the number classes')
    parser.add_argument('--image_size',
                        type=int,
                        default=64,
                        help='image size')
    parser.add_argument('--accumulation_steps',
                        type=int,
                        default=1,
                        help='Gradient accumulation steps')
    parser.add_argument('--resume',
                        type=str,
                        default=None,
                        help='Checkpoint state_dict file to resume training from')

    return parser.parse_args()


args = parse_args()

# 1. Torch choose cuda or cpu
if torch.cuda.is_available():
    if args.cuda:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
    if not args.cuda:
        print("WARNING: It looks like you have a CUDA device, but you aren't using it" +
              "\n You can set the parameter of cuda to True.")
        torch.set_default_tensor_type('torch.FloatTensor')
else:
    torch.set_default_tensor_type('torch.FloatTensor')

# os.path.exists() returns a bool, never None, so test it with "not"
if not os.path.exists(args.save_folder):
    os.makedirs(args.save_folder)

# 2. Log
get_logger(args.log_folder, args.log_name)
logger = logging.getLogger(args.log_name)


def train():
    # 3. Create SummaryWriter
    if args.tensorboard:
        from torch.utils.tensorboard import SummaryWriter
        # tensorboard  loss
        writer = SummaryWriter(args.tensorboard_log)

    # 4. Ready dataset
    if args.dataset == 'CIFAR10':
        # Check for a missing path first; otherwise this branch can never be reached
        if args.dataset_root is None:
            raise ValueError("Must provide --dataset_root when training on CIFAR10.")

        elif args.dataset_root != CIFAR_path:
            raise ValueError('dataset_root must be CIFAR_path when dataset is CIFAR10.')

        dataset = torchvision.datasets.CIFAR10(root=args.dataset_root, train=True,
                                               transform=torchvision.transforms.Compose([
                                                   transforms.Resize((args.image_size,
                                                                      args.image_size)),
                                                   torchvision.transforms.ToTensor()]))
    elif args.dataset == 'pets':
        # Check for a missing path first; otherwise this branch can never be reached
        if args.dataset_root is None:
            raise ValueError("Must provide --dataset_root when training on PETS.")

        elif args.dataset_root != PETS_path:
            raise ValueError('dataset_root must be PETS_path when dataset is pets.')

        dataset = pets(root_dir=args.dataset_root, transform=transforms.Compose([
            transforms.Resize((args.image_size, args.image_size)), transforms.ToTensor()]))

    dataloader = torch.utils.data.DataLoader(dataset=dataset,
                                             batch_size=args.batch_size,
                                             num_workers=args.num_workers,
                                             shuffle=True,
                                             collate_fn=collate)
    top1 = AverageMeter()
    losses = AverageMeter()

    # 5. Define train model
    if args.basenet == 'vgg':
        if args.depth == 16:
            model = vgg16(num_classes=args.num_classes, pretrained=args.pretrained)
        else:
            raise ValueError("Unsupported model depth!")
    else:
        raise ValueError('Unsupported model type!')

    if args.cuda:
        if torch.cuda.is_available():
            model = model.cuda()
            model = torch.nn.DataParallel(model).cuda()
    else:
        model = torch.nn.DataParallel(model)

    # 6. Loading weights
    if args.resume:
        other, ext = os.path.splitext(args.resume)
        if ext in ('.pkl', '.pth'):
            print('Loading weights into state dict...')
            model_load = os.path.join(args.save_folder, args.resume)
            model.load_state_dict(torch.load(model_load))
        else:
            print('Sorry only .pth and .pkl files supported.')
    elif args.resume is None:
        print("Initializing weights...")
        # NOTE: apply() visits every Conv2d/BatchNorm2d/Linear layer, so this also
        # overwrites any weights vgg16() loaded above when pretrained=True.
        model.apply(weights_init)

    model.train()

    iteration = 0

    # 7. Optimizer
    optimizer = optim.AdamW(model.parameters(), lr=args.lr)
    criterion = nn.CrossEntropyLoss()

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, patience=3, verbose=True)

    # 8. Length
    iter_size = len(dataset) // args.batch_size
    print("len(dataset): {}, iter_size: {}".format(len(dataset), iter_size))
    logger.info(f"args - {args}")
    t0 = time.time()

    # 9. Create batch iterator
    for epoch in range(args.epochs):
        t1 = time.time()
        torch.cuda.empty_cache()
        # 10. Load train data
        for data in dataloader:
            iteration += 1
            images, targets = data
            if args.cuda:
                images, targets = images.cuda(), targets.cuda()

            optimizer.zero_grad()
            # 11. Forward
            outputs = model(images)
            if args.cuda:
                criterion = criterion.cuda()

            loss = criterion(outputs, targets.long())
            loss = loss / args.accumulation_steps

            # 12. Backward
            loss.backward()
            optimizer.step()

            # 13. Measure accuracy and record loss
            acc1, _ = accuracy(outputs, targets, topk=(1, 1))
            top1.update(acc1.item(), images.size(0))
            losses.update(loss.item(), images.size(0))

            if args.tensorboard:
                writer.add_scalar("train_classification_loss", loss.item(), iteration)

            if iteration % 1 == 0:
                logger.info(
                    f"- epoch: {epoch},  iteration: {iteration}, lr:{optimizer.param_groups[0]['lr']}, "
                    f"top1 acc: {acc1.item():.2f}% , "
                    f"loss: {loss.item():.3f}, (losses.avg): {losses.avg:.3f} "
                )

        scheduler.step(losses.avg)

        t2 = time.time()
        h_time = (t2 - t1) // 3600
        m_time = ((t2 - t1) % 3600) // 60
        s_time = ((t2 - t1) % 3600) % 60
        print("epoch {} is finished, and the time is {}h{}min{}s".format(epoch, int(h_time), int(m_time), int(s_time)))

        # 14. Save train model
        if epoch != 0 and epoch % 5 == 0:
            print('Saving state, iter:', epoch)
            torch.save(model.state_dict(),
                       args.save_folder + '/' + args.dataset +
                       '_' + args.basenet + str(args.depth) + '_' + repr(epoch) + '.pth')
        torch.save(model.state_dict(),
                   args.save_folder + '/' + args.dataset + "_" + args.basenet + str(args.depth) + '.pth')

    if args.tensorboard:
        writer.close()

    t3 = time.time()
    h = (t3 - t0) // 3600
    m = ((t3 - t0) % 3600) // 60
    s = ((t3 - t0) % 3600) % 60
    print("The Finished Time is {}h{}m{}s".format(int(h), int(m), int(s)))
    return top1.avg, losses.avg


def weights_init(m):
    if isinstance(m, nn.Conv2d):
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias, 0.0)
    elif isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight)
        nn.init.constant_(m.bias, 0.0)


if __name__ == '__main__':
    torch.multiprocessing.set_start_method('spawn')
    logger.info("Program started")
    top1, loss = train()
    print("top1 acc: {}, loss:{}".format(top1, loss))
    logger.info("Done!")
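
A note on the boolean flags in parse_args above: `--cuda`, `--pretrained` and `--tensorboard` are declared with `type=str` and a boolean default. Any value given on the command line therefore arrives as a non-empty string, which is truthy, so `--cuda False` would still enable CUDA. One common workaround is a small converter, sketched here (`str2bool` is a hypothetical helper, not part of the repository):

```python
import argparse


def str2bool(value):
    """Convert common textual spellings of true/false into a real bool."""
    if isinstance(value, bool):
        return value
    if value.lower() in ("true", "t", "yes", "y", "1"):
        return True
    if value.lower() in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}")


parser = argparse.ArgumentParser()
parser.add_argument("--cuda", type=str2bool, default=True, help="Use CUDA to train model")

on = parser.parse_args([]).cuda                    # default is used
off = parser.parse_args(["--cuda", "False"]).cuda  # parsed into a real False
```

With this converter, `if args.cuda:` behaves as the flag name suggests.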

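Results

Because the dataset is so small, the final average accuracy is only 50.77% (it could be improved later by enlarging the dataset, changing img_size, and so on).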

Two additional experiments

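1. Set the number of epochs to 200 and retrain on the cat/dog dataset (optimizing all of the model's weights)

The accuracy eventually reaches about 90%, and the average accuracy reaches about 70%: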

optimizer = optim.AdamW(model.parameters(), lr=args.lr)

This shows that, in this experiment, transfer learning that optimizes all of the model's weights works.
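
The gap between the final accuracy and the average accuracy is expected if the average is a running mean over every batch since the start of training, which is how `top1.avg` appears to be used in train.py (the meter is never reset between epochs). Below is a minimal sketch, assuming the meter keeps a sample-weighted sum and count; utils/AverageMeter.py is not shown here, so `RunningAverage` is a hypothetical stand-in and the accuracy values are made up for illustration:

```python
class RunningAverage:
    """Sample-weighted running mean (assumed behaviour of AverageMeter)."""

    def __init__(self):
        self.sum = 0.0
        self.count = 0

    def update(self, value, n=1):
        self.sum += value * n
        self.count += n

    @property
    def avg(self):
        return self.sum / self.count if self.count else 0.0


meter = RunningAverage()
# Hypothetical per-batch accuracies that improve as training progresses
for acc in [50.0, 60.0, 70.0, 80.0, 90.0]:
    meter.update(acc, n=64)

final_batch_acc = 90.0
overall_avg = meter.avg
```

The early, low-accuracy batches keep pulling the running mean down, so `overall_avg` stays well below `final_batch_acc`.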

Results


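2. Set the number of epochs to 200 and retrain on the cat/dog dataset (optimizing only the weights of the model's last layer)

The accuracy eventually reaches about 70%, and the average accuracy reaches about 67%: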

optimizer = optim.AdamW(model.classifier.parameters(), lr=args.lr)

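This shows that, in this experiment, transfer learning that optimizes only the last layer's weights performs worse than transfer learning that optimizes all of the model's weights.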

Results

That's a wrap!
